from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from agents import RunErrorDetails

from agency_swarm import Agent, GuardrailFunctionOutput, OutputGuardrailTripwireTriggered, ThreadManager
from agency_swarm.agent.core import AgencyContext


def _make_tripwire(
    agent_output: str,
    guidance: str,
    *,
    include_run_data: bool = True,
) -> OutputGuardrailTripwireTriggered:
    """Build (without raising) an ``OutputGuardrailTripwireTriggered`` for retry tests.

    Args:
        agent_output: Stored as ``agent_output`` on the fake guardrail result and,
            when run data is attached, used as the content of a single
            assistant run item.
        guidance: Stored as ``output_info`` on a tripped ``GuardrailFunctionOutput``.
        include_run_data: When true, attach a ``RunErrorDetails`` as the
            exception's ``run_data``; when false, leave ``run_data`` unset so the
            code under test must fall back to the guardrail result alone.

    Returns:
        The constructed exception instance.
    """

    class _GuardrailObj:
        # Empty placeholder standing in for the guardrail that tripped.
        pass

    class _MockRunItem:
        # Minimal run item: only ``to_input_item`` is provided.
        def __init__(self, role: str, content: str):
            self.role = role
            self.content = content

        def to_input_item(self):
            return {"role": self.role, "content": self.content}

    result_fields = {
        "agent_output": agent_output,
        "output": GuardrailFunctionOutput(output_info=guidance, tripwire_triggered=True),
        "guardrail": _GuardrailObj(),
    }
    # Ad-hoc result object whose (class-level) attributes mimic a guardrail result.
    guardrail_result = type("_OutputGuardrailResult", (), result_fields)()
    tripwire = OutputGuardrailTripwireTriggered(guardrail_result)

    if not include_run_data:
        return tripwire

    # Attach run details carrying the rejected assistant output as a run item.
    tripwire.run_data = RunErrorDetails(
        input=[],
        new_items=[_MockRunItem("assistant", agent_output)],
        raw_responses=[],
        last_agent=None,
        context_wrapper=None,
        input_guardrail_results=[],
        output_guardrail_results=[],
    )
    return tripwire


@pytest.mark.asyncio
@patch("agency_swarm.agent.execution_helpers.Runner.run", new_callable=AsyncMock)
async def test_output_guardrail_retries_update_history(mock_runner_run):
    """A tripped output guardrail is retried once and recorded in thread history.

    The first ``Runner.run`` call raises an output-guardrail tripwire and the
    second returns a result-like object. The test checks four things:
    - the retry's result is returned;
    - exactly two runs happened;
    - the history holds the user prompt, the rejected assistant output and the
      guidance as a system message;
    - that guidance message is tagged ``output_guardrail_error``.
    """
    agent = Agent(name="RetryAgent", instructions="Test", validation_attempts=1)

    # Minimal agency context whose thread manager captures the conversation.
    ctx = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})

    # First attempt trips the guardrail; second returns a minimal RunResult-like object.
    mock_runner_run.side_effect = [
        _make_tripwire(agent_output="BAD OUTPUT", guidance="ERROR: fix format"),
        MagicMock(new_items=[], final_output="GOOD"),
    ]

    res = await agent.get_response(message="What is openai?", agency_context=ctx)
    assert getattr(res, "final_output", None) == "GOOD"
    # One tripped attempt plus one successful retry.
    assert mock_runner_run.call_count == 2

    # History must contain the initial user message, the rejected assistant
    # output, and the guidance injected as a system message.
    all_msgs = ctx.thread_manager.get_all_messages()
    role_content = [(m.get("role"), m.get("content")) for m in all_msgs]
    assert ("user", "What is openai?") in role_content
    assert ("assistant", "BAD OUTPUT") in role_content
    assert ("system", "ERROR: fix format") in role_content

    # Check the guidance message itself, not merely whichever system message
    # happens to be last, for the output-guardrail classification.
    guidance_msgs = [
        m for m in all_msgs if m.get("role") == "system" and m.get("content") == "ERROR: fix format"
    ]
    assert guidance_msgs
    assert guidance_msgs[-1].get("message_origin") == "output_guardrail_error"


@pytest.mark.asyncio
@patch("agency_swarm.agent.execution_helpers.Runner.run", new_callable=AsyncMock)
async def test_output_guardrail_retries_without_run_data(mock_runner_run):
    """Retry still records history when the tripwire carries no ``run_data``.

    The rejected output must then come from the guardrail result's
    ``agent_output``. The assertions mirror the run-data variant so both code
    paths are held to the same expectations.
    """
    agent = Agent(name="RetryAgentNoRunData", instructions="Test", validation_attempts=1)
    ctx = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})

    # First attempt trips without run data; second returns a minimal result-like object.
    mock_runner_run.side_effect = [
        _make_tripwire(agent_output="MALFORMED", guidance="Provide JSON", include_run_data=False),
        MagicMock(new_items=[], final_output="RECOVERED"),
    ]

    result = await agent.get_response(message="Fix this", agency_context=ctx)
    assert getattr(result, "final_output", None) == "RECOVERED"
    # One tripped attempt plus one successful retry.
    assert mock_runner_run.call_count == 2

    history = ctx.thread_manager.get_all_messages()
    contents = [(m.get("role"), m.get("content")) for m in history]
    assert ("user", "Fix this") in contents
    assert ("assistant", "MALFORMED") in contents
    assert ("system", "Provide JSON") in contents

    # The guidance must be classified the same way as in the run-data path.
    guidance_msgs = [m for m in history if m.get("role") == "system" and m.get("content") == "Provide JSON"]
    assert guidance_msgs
    assert guidance_msgs[-1].get("message_origin") == "output_guardrail_error"


class _DummyStream:
    """Fake streaming run result that replays a fixed sequence of events."""

    def __init__(self, events):
        # Replayed in order, unchanged, by ``stream_events``.
        self._events = events

    async def stream_events(self):
        """Asynchronously yield every stored event in its original order."""
        for event in self._events:
            yield event

    def cancel(self):
        """Do nothing; exists only so callers may cancel the stream."""


class _SimpleEvent:
    """Bare stand-in event exposing only a ``type`` string."""

    def __init__(self, t: str) -> None:
        # The only attribute this fake provides.
        self.type = t


@pytest.mark.asyncio
@patch("agency_swarm.agent.execution_helpers.Runner.run_streamed")
async def test_output_guardrail_retries_streaming(mock_run_streamed):
    """Streaming path retries after an output-guardrail tripwire.

    The first ``run_streamed`` call raises the tripwire. The second returns a
    stream yielding one event, which must reach the caller, and the guidance
    must be recorded in history as a system message.
    """
    agent = Agent(name="RetryStreamAgent", instructions="Test", validation_attempts=1)
    ctx = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})

    # First call raises; second returns a dummy stream with one event.
    mock_run_streamed.side_effect = [
        _make_tripwire(agent_output="STREAM BAD", guidance="ERROR: needs header"),
        _DummyStream([_SimpleEvent("run_item_stream_event")]),
    ]

    # Collect streamed events.
    received = []
    async for ev in agent.get_response_stream(message="Hello", agency_context=ctx):
        received.append(ev)

    assert received, "expected events from second attempt"
    # One tripped attempt plus one successful retry.
    assert mock_run_streamed.call_count == 2

    # The guidance is stored as a system message in history.
    msgs = ctx.thread_manager.get_all_messages()
    roles_contents = [(m.get("role"), m.get("content")) for m in msgs]
    assert ("system", "ERROR: needs header") in roles_contents
